import logging
import os
import re
import time

from flask import Blueprint
from flask import request

from model.db import db
from model.mlmodel import Model
from model.task import Task
from algo.mlmodel.model_entry import kill_process, ml_entry
from utils.mlmodel_util import update_task_when_pred
from utils.model_util import save_model, model_type
from utils.database_util import read_gp, save_gp
from utils.format_util import py_to_java, cast_float

# Blueprint that carries the /ml_model endpoint defined below.
mlmodel = Blueprint(
    'mlmodel',
    __name__,
)

# Root logging configuration applied when this module is imported:
# INFO level, timestamped "time - logger - level - message" lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)


def get_model_by_id(model_id):
    """Return the ``Model`` row whose id is *model_id*, or ``None`` if absent.

    If the first query raises, the session is rolled back (presumably to clear
    a failed transaction left by earlier work) and the query is retried once;
    an error on the retry propagates to the caller.
    """
    try:
        return Model.query.filter_by(id=model_id).first()
    except Exception:
        # ``except Exception`` instead of a bare ``except:`` so that
        # KeyboardInterrupt / SystemExit are not swallowed and retried.
        db.session.rollback()
        logging.warning("session rollback", exc_info=True)
        return Model.query.filter_by(id=model_id).first()

# Table references accepted from the request: dot-separated parts, each either
# plain word characters (letters, digits, "_", "$") or a double-quoted
# identifier containing no double quote. Outside quotes this rejects
# whitespace, ";", quotes, parentheses and comment markers, so a request cannot
# smuggle extra SQL into the "select ... from" statement.
_TABLE_NAME_RE = re.compile(r'(?:[\w$]+|"[^"]+")(?:\.(?:[\w$]+|"[^"]+"))*')


def _quote_ident(name):
    """Quote *name* as a SQL identifier, doubling any embedded double quote."""
    return '"' + str(name).replace('"', '""') + '"'


def _select_columns(para_list, algo_type, execution):
    """Return the column list for the dataset query.

    Classification/regression training reads only the feature columns,
    ``_record_id_`` and the label column; every other case reads ``*``.
    """
    if execution == "train" and model_type(algo_type) in ["classification", "regression"]:
        wanted = para_list["feature_col"] + ["_record_id_", para_list["label_col"]]
        # dict.fromkeys de-duplicates while keeping a stable order (set
        # iteration order for strings can differ between interpreter runs).
        return ",".join(_quote_ident(col) for col in dict.fromkeys(wanted))
    return "*"


def _mark_train_failed(model, error):
    """Record a failed training run on *model*, clear its pid, and commit."""
    model.status = "FAILED"
    model.other_info = str({"error": error})
    model.progress_id = 0
    db.session.commit()


def _release_process(model):
    """Clear *model*'s recorded pid (so "kill" reports it finished) and commit."""
    model.progress_id = 0
    db.session.commit()


def _finish_training(model, model_id, para_list, output, record,
                     trained_model, train_inf, starttime):
    """Persist training results and fill in *model*'s result fields.

    Saves ``train_inf`` to ``ml_model.output_<model_id>``, stores the trained
    model via ``save_model`` and sets timing/status fields on *model* without
    committing. Returns ``True`` on success; on failure sets
    ``output["error"]``, marks the model FAILED (committed) and returns
    ``False``.
    """
    if train_inf is None:
        logging.error("df_pred is None")
        output["error"] = "df_pred is None"
        _mark_train_failed(model, output["error"])
        return False

    y_pred_table_name = "ml_model.output_" + str(model_id)
    try:
        save_gp(y_pred_table_name, train_inf)
        object_name = "{}/{}/{}_{}_{}".format(para_list["usr_id"],
                                              model_type(para_list["algo"]),
                                              para_list["algo"],
                                              model_id, str(int(time.time())))
        model_saved_path = "s3a://ml-model/" + object_name
        logging.info(model_saved_path)
        save_model(trained_model, model, object_name)
        logging.info("training data output saved")
        algotime = time.time() - starttime

        other_info = record.setdefault("other_info", {})
        other_info["message"] = py_to_java(str(output["message"]))
        other_info["out_table"] = y_pred_table_name
        record = cast_float(record)
        # Validation only: the text stored in the DB must parse back as a
        # Python expression. The input is str() of algorithm output here,
        # not raw request data.
        eval(str(record["other_info"]))
    except Exception as e:
        logging.exception("failed to save training results")
        db.session.rollback()
        output["error"] = py_to_java(str(e))
        _mark_train_failed(model, output["error"])
        return False

    other_info = py_to_java(str(record["other_info"]))
    # NOTE(review): record may be shared with output by ml_entry; keep this
    # mutation unless that is confirmed not to be the case.
    record["other_info"] = other_info

    model.other_info = other_info
    model.algotime = algotime
    model.sparktime = 0
    model.runtime = round(algotime + 0.1, 2)
    model.status = "FINISHED"
    model.model_saved_path = model_saved_path
    logging.info("%s %s", model_id, model.status)
    return True


def _finish_prediction(output, target, task_id, df_pred, para_list):
    """Write predictions to *target* and update the task row.

    Returns ``True`` on success; on failure rolls the session back, sets
    ``output["error"]`` and returns ``False``.
    """
    output["result"]["output_params"]["target"] = target
    try:
        save_gp(target, df_pred)
        task = Task.query.filter_by(id=int(task_id)).first()
        update_task_when_pred(task, df_pred, para_list, target)
        db.session.commit()
    except Exception as e:
        logging.error(e)
        # Leave the session usable for the caller's follow-up commit.
        db.session.rollback()
        output["error"] = py_to_java(str(e))
        return False
    return True


@mlmodel.route('/ml_model', methods=['POST'])
def ml_model():
    """Handle a train / predict / kill request for a stored model.

    Form fields:
        execution: "train", "predict" or "kill".
        model_id: id of the ``Model`` row to operate on.
        target, taskId, source, feature_X: read only for "predict". ``source``
            is the input table (``dataset.`` is prepended when it contains no
            "."), ``target`` the table predictions are written to, and
            ``taskId`` the ``Task`` row updated afterwards.

    Returns:
        For "kill", a plain status string. Otherwise the ``output`` dict
        (``status``, ``error``, ``message``, ``result``): ``status`` is 200 on
        success and 500 on failure, in which case ``error`` holds the message.
    """
    logging.info("enter ml_model")
    starttime = time.time()
    form = request.form
    execution = form.get("execution")
    model_id = form.get("model_id")
    target = task_id = source = feature_X = None
    if execution == "predict":
        target = form.get("target")
        task_id = form.get("taskId")
        source = form.get("source")
        feature_X = form.get("feature_X")

    # Response skeleton; "status" stays 500 unless the run succeeds.
    output = {"status": 500, "error": [], "message": [], "result": {}}

    model = get_model_by_id(model_id)
    if model is None:
        output["error"] = "model {} not found".format(model_id)
        return output

    if execution == "kill":
        # Read the pid recorded by the running job. This request must not
        # overwrite progress_id with its own pid first, or it would kill itself.
        prev_pid = model.progress_id
        if not prev_pid:
            return "previous process already finished"
        kill_process(prev_pid)
        _release_process(model)
        return "previous process has been killed"

    if execution == "predict":
        if not source:
            output["error"] = "missing source"
            return output
        if "." not in source:
            source = "dataset." + source
        # source comes straight from the request and is interpolated into SQL.
        if not _TABLE_NAME_RE.fullmatch(source):
            output["error"] = "invalid source table: {}".format(source)
            return output

    algo_type = model.algorithm
    # SECURITY(review): eval executes arbitrary code stored in model.param. If
    # the column only ever holds a dict literal, ast.literal_eval is a safe
    # drop-in replacement.
    para_list = eval(model.param)
    logging.info("para_list: %s", para_list)
    para_list["model_id"] = model_id
    para_list["execution"] = execution
    para_list["source"] = model.source_table
    if execution == "predict":
        para_list["source"] = source
        para_list["feature_X"] = feature_X
    output["result"]["input_params"] = para_list
    output["result"]["output_params"] = {}

    # Record which process runs this job so a later "kill" request can stop it.
    # NOTE(review): os.getpid() is the pid of the process serving this request,
    # so a kill presumably stops that whole process — confirm this is intended.
    model.progress_id = os.getpid()
    db.session.commit()

    record = {}
    try:
        sql = "select {} from {}".format(
            _select_columns(para_list, algo_type, execution), para_list["source"])
        dataframe = read_gp(sql)
        record, output, df_pred, trained_model, train_inf = ml_entry(
            para_list, output, dataframe, record)
        logging.info("Algorithm Finished")
    except Exception as e:
        logging.exception("error occurs in training/predicting")
        output["error"] = py_to_java(str(e))
    logging.info("training finished")
    # Re-fetch the model row after the (possibly long) run.
    model = get_model_by_id(model_id)
    logging.info("new model loaded")

    # ml_entry may also report failure through output["error"] without raising.
    if output["error"] != []:
        if execution == "train":
            logging.error(output["error"])
            _mark_train_failed(model, output["error"])
        else:
            _release_process(model)
        return output

    if execution == "train":
        if not _finish_training(model, model_id, para_list, output, record,
                                trained_model, train_inf, starttime):
            return output
    elif execution == "predict":
        if not _finish_prediction(output, target, task_id, df_pred, para_list):
            _release_process(model)
            return output
    else:
        logging.info("wrong execution type")

    output["status"] = 200
    model.progress_id = 0
    db.session.commit()
    return output
